import scipy.io as sio
import torch
import matplotlib.pyplot as plt
import csv
import imshift
from class_Angular_Spectrum import Angular_Spectrum
import os
import time

#device
# Expose only GPU 0 to this process, then fall back to the CPU when CUDA
# is not available.
os.environ['CUDA_VISIBLE_DEVICES'] = "0"
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Reference instant; the elapsed setup/read time is measured against it later.
start_time = time.time()

#image size (pixels)
N1 = 512
N2 = 512

#ptychography parameters
pxsize = 50e-3  #mm
wavlen = 116e-3 #mm
dist = 20 #mm


# object translation positions
# A K1 x K2 raster of scan offsets centred on zero with pitch `step`
# (presumably in mm, like pos_error below -- TODO confirm).
K1 = 8
K2 = 8
K = K1*K2
step = 1.2
_half_span_1 = step*(K1-1)/2
_half_span_2 = step*(K2-1)/2
shifts_1_range = torch.linspace(-_half_span_1, _half_span_1, K1)
shifts_2_range = torch.linspace(-_half_span_2, _half_span_2, K2)
# 'ij' grids are transposed so that, after flattening, the first coordinate
# varies fastest and the second one slowest.
shifts_1, shifts_2 = (
    grid.t()
    for grid in torch.meshgrid(shifts_1_range, shifts_2_range, indexing='ij')
)

# Uniform jitter in [-pos_error/2, pos_error/2) added to every nominal position.
pos_error = 0.005   #estimation error for the probe positions (mm)
perturbs_1 = pos_error*torch.rand_like(shifts_1) - pos_error/2
perturbs_2 = pos_error*torch.rand_like(shifts_2) - pos_error/2

# Final layout: one (K, 1) column per coordinate, stored on the compute device.
shifts_1 = (shifts_1 + perturbs_1).reshape(K, 1).to(device)
shifts_2 = (shifts_2 + perturbs_2).reshape(K, 1).to(device)

#Read data
# Measurements are read from the 'tensor' entry of the .mat file and promoted
# to complex128 (zero imaginary part) so they can be combined directly with
# the complex fields used in the iterations.
mat_contents = sio.loadmat('image3_8x8.mat')
y = torch.from_numpy(mat_contents['tensor']).to(device)
y = y.double()
y = torch.complex(y, torch.zeros_like(y, dtype=torch.float64))

#probe initialization
# The initial probe guess comes from the 'probe_est' entry of its .mat file.
probe_est = sio.loadmat('probe_est_initial.mat')['probe_est']
probe_est = torch.from_numpy(probe_est).to(device)
# Sized from N1/N2 (not a literal 512) so it follows the configured image size.
imag_part = torch.zeros((N1, N2), dtype=torch.float64, device=device)
if torch.is_complex(probe_est):
    # A probe stored as complex only needs its precision normalised.
    probe_est = probe_est.to(torch.complex128)
else:
    # torch.complex requires real and imaginary parts of the same dtype, so
    # cast the stored probe to float64 first (same treatment as y above).
    probe_est = torch.complex(probe_est.double(), imag_part)

#object initialization
# Flat start: unit amplitude, zero phase everywhere.
obj_est = torch.ones((N1, N2), dtype=torch.complex128, device=device)

#running iterations
n_iters = 5000

#Angular Spectrum Method
# Two operators from class_Angular_Spectrum that differ only in the sign of
# the distance argument (presumably forward / backward propagation -- confirm
# against the class definition).
ASM1=Angular_Spectrum(wavlen, N1, pxsize, dist)
ASM2=Angular_Spectrum(wavlen, N1, pxsize, -dist)

# Output folders: '1cycle/' and '100cycle/' live under the DM results folder.
dir_save = './results/DM/'
dir_save1 = dir_save + '1cycle/'
dir_save2 = dir_save + '100cycle/'
for _out_dir in (dir_save, dir_save1, dir_save2):
    os.makedirs(_out_dir, exist_ok=True)

# Column names of the per-cycle timing CSV.
header = ['no','cycle time','total time','read time']
# Time elapsed since start_time, i.e. everything executed up to this point.
data_read_time = time.time()-start_time


#mask
# Binary window: 1 inside rows/cols 128..383, 0 elsewhere. The tensor is sized
# from N1/N2 (not a literal 512) because it is multiplied element-wise with
# the (N1, N2) object and probe estimates.
mask = torch.zeros((N1, N2), dtype=torch.double, device=device)
mask[128:384, 128:384] = 1.0

#DM algorithms
# Each cycle of the "DM" reconstruction (presumably difference map) does:
#   1. update every exit wave EWs[j] using the measured data y[j],
#   2. re-estimate the probe from the exit waves and the shifted object,
#   3. re-estimate the object from the exit waves and the shifted probe,
# then logs timing to a CSV and periodically saves .mat/.png snapshots.
from tqdm import tqdm


def _clip_amplitude(field, max_amp=2):
    """Cap |field| at max_amp in place, keeping the phase; return field."""
    too_high = torch.abs(field) > max_amp
    field[too_high] = max_amp * torch.exp(1j * torch.angle(field[too_high]))
    return field


def _save_estimates(filename, obj, probe):
    """Write amplitude and phase of the object and probe to a .mat file."""
    sio.savemat(filename, {
        'output_obj_A': torch.abs(obj).cpu().numpy(),
        'output_obj_P': torch.angle(obj).cpu().numpy(),
        'output_pro_A': torch.abs(probe).cpu().numpy(),
        'output_pro_P': torch.angle(probe).cpu().numpy(),
    })


def _plot_estimates(filename, cycle, obj, probe):
    """Save and show a 2x4 overview: object on top, probe below.

    Columns 1-2 show amplitude/phase of the central window, columns 3-4 the
    full field.
    """
    # Same rows/cols as the window set to 1 in the support mask (128..383).
    # The previous crop started at 129 and dropped the first row and column.
    roi = (slice(128, 384), slice(128, 384))
    cycle_title = 'cycle' + (str(cycle).zfill(6))
    panels = [
        (torch.abs(obj[roi]), 'gray', 'lable (amp)'),
        (torch.angle(obj[roi]), 'inferno', 'lable (pha)'),
        (torch.abs(obj), 'gray', 'lable (amp)'),
        (torch.angle(obj), 'inferno', 'lable (pha)'),
        (torch.abs(probe[roi]), 'gray', cycle_title),
        (torch.angle(probe[roi]), 'inferno', None),
        (torch.abs(probe), 'gray', cycle_title),
        (torch.angle(probe), 'inferno', None),
    ]

    # A fresh figure per snapshot, closed afterwards: otherwise, on backends
    # where plt.show() does not block, every snapshot is drawn onto the same
    # axes and keeps stacking colorbars (and memory).
    fig = plt.figure()
    for position, (image, cmap, title) in enumerate(panels, start=1):
        plt.subplot(2, 4, position)
        plt.imshow(image.cpu().numpy(), cmap=cmap)
        if title is not None:
            plt.title(title)
        plt.colorbar(fraction=0.05, pad=0.05)
        plt.axis('off')
    plt.savefig(filename)
    plt.show()
    plt.close(fig)


file_time = 0   # running sum of the per-cycle times (s)
# One exit-wave estimate per scan position, all initialised with the probe.
EWs = torch.zeros((K, N1, N2), dtype=torch.complex128, device=device) + probe_est
c = 1e-10       # keeps the probe/object divisions away from 0/0
J = K           # number of scan positions visited per cycle
path_csv = dir_save + 'DM' + '_sample' + '.csv'

# Scan offsets in pixels, computed once instead of at every imshift call.
pix_shifts_1 = shifts_1 / pxsize
pix_shifts_2 = shifts_2 / pxsize

for cycle in tqdm(range(1, n_iters + 1)):

    iter_start = time.time()

    # Visit the scan positions in a new random order every cycle.
    JJ = torch.randperm(K)
    order = JJ[:J].tolist()

    ######################################################
    # 1) exit-wave update
    # NOTE(review): hoisting the masked field and reusing one imshift result
    # assumes imshift.imshift does not modify its input and is deterministic.
    masked_obj = mask * obj_est
    for j in order:
        tempEW = probe_est * imshift.imshift(masked_obj, pix_shifts_1[j], pix_shifts_2[j], device)
        tempEW0 = ASM1(2 * tempEW - EWs[j, :, :])
        # Keep the phase of the ASM1 output, replace its modulus by sqrt(y[j]),
        # then apply ASM2.
        reviseEW = ASM2(torch.sqrt(y[j, :, :]) * torch.exp(1j * torch.angle(tempEW0)))
        EWs[j, :, :] = EWs[j, :, :] + reviseEW - tempEW

    ######################################################
    # 2) probe update: sum(conj(O_j) * EW_j) / (sum(|O_j|^2) + c)
    numP = torch.zeros((N1, N2), dtype=torch.complex128, device=device)
    denP = torch.zeros((N1, N2), dtype=torch.complex128, device=device)

    for j in order:
        # Shifted object computed once and used for both accumulators.
        shifted_obj = imshift.imshift(masked_obj, pix_shifts_1[j], pix_shifts_2[j], device)
        numP = numP + torch.conj(shifted_obj) * EWs[j, :, :]
        denP = denP + torch.abs(shifted_obj) ** 2

    probe_est = _clip_amplitude(numP / (denP + c))

    ######################################################
    # 3) object update: same form with the roles swapped; the probe and the
    # exit waves are shifted by the negated scan offsets.
    numO = torch.zeros((N1, N2), dtype=torch.complex128, device=device)
    denO = torch.zeros((N1, N2), dtype=torch.complex128, device=device)

    masked_probe = mask * probe_est
    for j in order:
        shifted_probe = imshift.imshift(masked_probe, -pix_shifts_1[j], -pix_shifts_2[j], device)
        shifted_EW = imshift.imshift(EWs[j, :, :], -pix_shifts_1[j], -pix_shifts_2[j], device)
        numO = numO + torch.conj(shifted_probe) * shifted_EW
        denO = denO + torch.abs(shifted_probe) ** 2

    obj_est = _clip_amplitude(numO / (denO + c))

    cycle_tag = 'DM' + '_cycle' + (str(cycle).zfill(6)) + '_results'

    #save data
    # Every one of the first 100 cycles is stored individually.
    if cycle <= 100:
        _save_estimates(dir_save1 + cycle_tag + '.mat', obj_est, probe_est)

    # The cycle time includes the save above but not the CSV/plot work below.
    iter_time = time.time() - iter_start
    file_time += iter_time
    values = [cycle, iter_time, file_time, data_read_time]

    # Append one row per cycle; the header is written only when the file is
    # first created. The context manager closes the file even on error.
    write_header = not os.path.isfile(path_csv)
    with open(path_csv, 'a', newline='') as csv_file:
        writer = csv.writer(csv_file)
        if write_header:
            writer.writerow(header)
        writer.writerow(values)

    # Snapshots (figure + .mat) at cycles 1, 10, 50 and every 100th cycle.
    if cycle in (1, 10, 50) or cycle % 100 == 0:
        _plot_estimates(dir_save2 + cycle_tag + '.png', cycle, obj_est, probe_est)
        _save_estimates(dir_save2 + cycle_tag + '.mat', obj_est, probe_est)


